Semesterprojekt im Studiengang Mobile Medien an der Hochschule der Medien.
Projektmitglieder:
Nico Burkart (Kürzel: nb094, Matrikel Nr. 34924),
Jonas Wolfram (Kürzel: jw132, Matrikel Nr. 35589)
Betreuer:
Prof. Dr. Joachim Charzinski
Projektzeitraum:
WS20/21
2.1.1. Muss-Ziele
2.1.2. Kann-Ziele
2.2. Implementation
2.2.1. Vorwort
2.2.2. Imports
2.2.3. Analysieren der Trainingsdaten
2.2.4. Vorbereiten der Daten
2.2.5. Architektur des Generators
2.2.6. Probe des untrainierten Generators
2.2.7. Architektur des Diskriminators
2.2.8. Definition der Fehlerfunktionen
2.2.9. Definition der Optimierer
2.2.10. Definition der Trainingsfunktion
2.2.11. Konstanten
2.2.12. Training
2.2.13. Auswertung
2.3. Probleme und Lösungen 2.4.1. Aufgabenaufteilung
2.4.2. Zeitplanung
2.5. Fazitfrom IPython.display import Image
Image("assets/images/gan_diagram.png")
Ein GAN besteht aus zwei neuronalen Netzen, dem Diskriminator (blau) und dem Generator (rot). Der Generator erstellt aus einem Vektor mit zufälligen Werten (weiß) eine Probe (grau, unten). Der Diskriminator bekommt als Input eine Probe aus den Trainingsdaten (grau, oben) und die Probe des Generators (grau, unten) und klassifiziert die Proben jeweils entweder als "echt" oder als "fake". Auf Basis dieser Klassifizierungen können die Kosten (Discriminator loss und Generator loss) berechnet werden, durch die die Netzwerke optimiert werden können.
Das Ziel des Generators ist Bilder zu erstellen, die der Diskriminator als "echt" klassifiziert. Der Diskriminator versucht sich dahingehend zu optimieren, dass er die Proben (grau) immer richtig klassifiziert. Beim Trainieren eines GANs versucht man das Gleichgewicht beider Modelle herzustellen. Wenn dies gewährleistet wird, lernt der Diskriminator die Merkmale der Trainingsdaten und gibt dem Generator auf Basis des Gelernten Feedback zu den generierten Proben, wodurch sich der Generator entsprechend anpasst und das Generierte immer mehr Ähnlichkeit zu den Trainingsdaten bekommmt. Irgendwann sind die generierten Proben nicht mehr zu unterscheiden von den echten Daten.
In unserem Projekt verwenden wir eine Art eines GANs namens DCGAN (Deep Convolutional Generative Adversarial Network). Bei einem DCGAN sind beide neuronalen Netze CNNs (Convolutional Neural Networks), welche sich vor allem eignen, wenn der Datensatz aus Bildern besteht. Der Vorteil von CNNs gegenüber herkömmlichen neuronalen Netzen ist, dass diese Merkmale auch erkennen, wenn z.B. die Position des Merkmals unter den verschiedenen Bildern variiert. Dementsprechend eignet sich diese Art von GAN für unseren Anwendungsfall (i.e. Neue Pokémon generieren welche Merkmale von bereits existierenden haben) sehr gut. Die Idee für diese Art des GANs stammt aus dem Paper "Unsupervised representation learning with deep convolutional generative adversarial networks" (Alec Radford et al., 2016, https://arxiv.org/pdf/1511.06434.pdf), auf dem die Umsetzung dieses Projekt größtenteils beruht.
Zuerst müssen wir die notwendigen Libraries und Module für die Implementierung eines GANs einbinden.
Für die Erstellung und das Training des GANs verwenden wir die Library tensorflow.keras. Ansonsten werden herkömmliche Machine Learning Libraries eingebunden (z.B. numpy). Zusätzlich binden wir noch eigene Module ein (ImageController, Evaluator, SaveController), welche sich im Hintergrund um das Anzeigen von Bildern, von Diagrammen und das Speichern der Modelle kümmern.
import tensorflow as tf
from tensorflow.keras.layers import Dense, Activation, ReLU, LeakyReLU, Reshape, Conv2DTranspose, Conv2D, BatchNormalization, Dropout, Flatten, InputLayer
from tensorflow.keras.models import Sequential
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.initializers import RandomNormal
import numpy as np
import os
import ImageController, Evaluator, SaveController
import time
Im nächsten Schritt machen wir uns mit den Trainingsdaten vertraut, indem wir uns ein paar Beispiele anschauen. Für unser GAN verwenden wir ausschließlich Pokémon Portraits. Wir vermuten, dass dadurch eine größere Schnittmenge an Features unter den Bildern vorkommt, wodurch es dem GAN einfacher fällt Ähnliches zu generieren. (Datensatz: https://www.kaggle.com/brilja/pokemon-mugshots-from-super-mystery-dungeon)
image_dir_url = 'data'
ImageController.plot_random_images(image_dir_url)
Als nächstes laden wir nun alle Bilder als Vektoren und müssen, um diese Daten nun optimal benutzen zu können, die Daten noch skalieren. Sie werden von herkömmlichen RGB Werten 0 - 255 auf Werte von -1 - 1 skaliert, damit sie die gleiche Form haben wie der Output des Generators.
train_images = ImageController.load_all_images(image_dir_url).astype('float32')
# Images should be normalized to a range of -1 to 1 in order to match the output of the generator.
train_images = (train_images - 127.5) / 127.5
BUFFER_SIZE = len(train_images)
BATCH_SIZE = 128
# This creates an iterator object which serves batches of size BATCH_SIZE consisting of random normalized images once iterated over.
train_dataset = tf.data.Dataset.from_tensor_slices(tensors=train_images).shuffle(buffer_size=BUFFER_SIZE).batch(batch_size=BATCH_SIZE, drop_remainder=True)
print(f'The shape of the dataset:\n{train_images.shape}')
Der Datensatz besteht aus 4881 Pokémon Bildern mit einer Breite von 64 Pixeln, einer Höhe von 64 Pixeln und 3 Kanälen (RGB).
Image("assets/images/1_rdXKdyfNjorzP10ZA3yNmQ.png")
Der Aufbau und die Eigenschaften des Generators wurden größtenteils aus dem oben genannten Paper übernommen.
Folgende Eigenschaften wurden übernommen:
# The kernel initializer defines how the weights in a neural network layer should be initialized. This has a direct effect on convergence speed. A recommended Distribution for training DCGANs is the following:
generator_kernel_initializer = RandomNormal(stddev=0.02)
def create_generator(noise_dimensions):
model = Sequential()
model.add(InputLayer(input_shape=(noise_dimensions,)))
# First layer has to be a dense layer which expands the noise vector. This has to be done to create the required 16.384 values which after being reshaped form the input feature maps for the first deconvolutional layer.
model.add(Dense(units=4 * 4 * 1024, kernel_initializer=generator_kernel_initializer, use_bias=False))
model.add(BatchNormalization())
model.add(Reshape(target_shape=(4, 4, 1024)))
# The output feature maps of the first deconvolutional layer have a shape of (8, 8, 512). This could be achieved by the stride set to 2 and padding set to 'same' which basically doubles the size of the image (i.e. the pixels).
model.add(Conv2DTranspose(filters=512, kernel_size=(5, 5), kernel_initializer=generator_kernel_initializer, strides=(2, 2), padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(ReLU())
# After the second one there are 256 feature maps with a size of 16x16 left.
model.add(Conv2DTranspose(filters=256, kernel_size=(5, 5), kernel_initializer=generator_kernel_initializer, strides=(2, 2), padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(ReLU())
# After the third block there are only 128 feature maps with a size of 32x32 left.
model.add(Conv2DTranspose(filters=128, kernel_size=(5, 5), kernel_initializer=generator_kernel_initializer, strides=(2, 2), padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(ReLU())
# After the last deconvolutional layer the output has a size of 64x64x3 (RGB image with a width and height of 64px).
model.add(Conv2DTranspose(filters=3, kernel_size=(5, 5), kernel_initializer=generator_kernel_initializer, strides=(2, 2), padding='same', use_bias=False))
# Tanh activation function is used to scale the values between -1 and 1 (same range as the preprocessed real images).
model.add(Activation(activation='tanh'))
assert model.output_shape == (None, 64, 64, 3)
return model
# Definition of the dimensionality of the noise vector:
NOISE_DIMENSIONS = 100
generator = create_generator(NOISE_DIMENSIONS)
print('The topography of the generator\'s model:\n')
generator.summary()
In der folgenden Zelle werden probeweise aus 16 100-dimensionalen Vektoren 16 Bilder durch den untrainierten Generator erstellt.
# Definition of how many sample images should be plotted and saved every epoch to monitor the progress:
EXAMPLES_TO_GENERATE = 16
# A seed is set so that each time the program is restarted the same vector is created. This is essential for reproducability.
tf.random.set_seed(0)
# The seed is a pseudo-random vector used throughout the training to generate images and compare the generator's output between different training epochs.
seed = tf.random.normal(shape=[EXAMPLES_TO_GENERATE, NOISE_DIMENSIONS])
print(f'The shape of the seed:\n{seed.shape}')
ImageController.generate_and_plot_images(generator, seed)
Man kann erkennen, dass die Bilder grau sind. Dies ist logisch, da derzeit alle trainierbaren Parameter normalverteilt sind. Die Werte liegen um den Mittelwert 0. Nachdem sie nun auf die Range 0 bis 255 skaliert werden (Standardwerte für Pixel), liegen alle in etwa bei der Hälfte (i.e. 128). Dadurch entsteht der graue Farbton.
Image("assets/images/dcgan.png")
Der Diskriminator ist im DCGAN das Spiegelbild des Generators. Die gesamte Architektur des Generators wird gespiegelt. Für den Diskriminator haben wir ebenso größtenteils die Empfehlungen von Alec Redford et al. übernommen, sowie noch einige Anpassungen vorgenommen.
Übernommene Eigenschaften:
Zusätzlich implementierte Ansätze:
# The kernel initializer defines how the weights in a neural network layer should be initialized. This has a direct effect on convergence speed. A recommended Distribution for training DCGANs is the following:
discriminator_kernel_initializer = RandomNormal(stddev=0.02)
def create_discriminator():
model = Sequential()
# RGB image with a width of 64 and a height of 64
model.add(InputLayer(input_shape=(64, 64, 3)))
# The output of the first convolutional layer is 128 feature maps with a width of 32 and a height of 32 due to stride set to 2 and padding set to 'same'.
model.add(Conv2D(filters=128, kernel_size=(5,5), kernel_initializer=discriminator_kernel_initializer, strides=(2, 2), padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(rate=0.3))
# The second one produces 256 feature maps which are 16x16.
model.add(Conv2D(filters=256, kernel_size=(5,5), kernel_initializer=discriminator_kernel_initializer, strides=(2, 2), padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(rate=0.3))
# The third one produces 512 with a shape of (8, 8).
model.add(Conv2D(filters=512, kernel_size=(5,5), kernel_initializer=discriminator_kernel_initializer, strides=(2, 2), padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(rate=0.3))
# The last convolutional layer has an output of (4, 4, 1024).
model.add(Conv2D(filters=1024, kernel_size=(5,5), kernel_initializer=discriminator_kernel_initializer, strides=(2, 2), padding='same', use_bias=False))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(rate=0.3))
# Matrices have to be flattened to a vector to be able to add the output layer which is fully connected.
model.add(Flatten())
model.add(Dense(units=1, kernel_initializer=discriminator_kernel_initializer, use_bias=False))
model.add(BatchNormalization())
# Sigmoid function is used to scale the output to a range of 0 to 1 which is the probability of an image being fake or real.
model.add(Activation(activation='sigmoid'))
return model
discriminator = create_discriminator()
print('The topography of the discriminator\'s model:\n')
discriminator.summary()
Die Fehlerfunktionen werden verwendet, um die Qualität der Netze zu messen. Auf Basis der Fehlerfunktionen können die Gewichtsanpassungen der Netze durchgeführt werden.
Die verwendete Fehlerfunktion ist die Binäre Kreuzentropie.
Die Fehlerfunktion des Diskriminators ist die Summe aus dem Wert der Fehlerfunktion auf den echten Daten und dem Wert der Fehlerfunktion auf den generierten Daten. Der Diskriminator wird im Laufe des Trainings versuchen diese Summe zu minimieren. Die Fehlerfunktion des Generators ist die Fehlerfunktion des Diskriminators, angewandt auf den generierten Bildern, abgezogen von 1. Der Generator versucht diese Differenz ebenso zu minimieren. Dadurch wird gewährleistet, dass der Generator einen geringen Fehler Wert hat, wenn der Diskriminator die generierten Bilder als "echt" klassifiert (i.e. echt = 1).
Ein typisches Problem beim Trainieren ist, dass der Diskriminator wesentlich schneller lernt, als der Generator. Dadurch entsteht das Problem, dass der Generator nicht mehr hilfreiches Feedback durch das Ergebnis der Fehlerfunktion bekommt, da dies durch den sehr guten Diskriminator immer maximal ist, auch wenn der Generator eventuell gerade ein gutes Feature gelernt hat. Um dies zu verhindern wird empfohlen Label Smoothing und Label Flipping anzuwenden (siehe https://github.com/soumith/ganhacks). Label Smoothing bewirkt, dass die Labels nicht mehr statische Werte (entweder 0 oder 1 haben), sondern eine Range von bspw. 0 bis 0.3 und 0.7 bis 1 respektive haben. Label Flipping fügt den Labels Noise hinzu. Genauer gesagt werden p-Prozent der Labels umgedreht (aus 0 wird 1, aus 1 wird 0). Im Folgenden werden diese Fehlerfunktionen definiert und auf Seiten des Diskriminators Label Smoothing und Label Flipping angewandt.
# Binary crossentropy is used for binary classification problems (i.e. real vs. fake).
binary_crossentropy = BinaryCrossentropy()
# The loss of the discriminator consists of the sum of the loss of the real image vs. its prediction and the loss of the fake image vs. its prediction.
def calc_discriminator_loss(real_output, fake_output):
y_true_positive = np.ones_like(real_output).flatten()
# Noise is added to the true values to prevent overfitting
y_true_positive = flip_labels(y_true_positive, 0.05)
# Labels are smoothed out to prevent overfitting
y_true_positive = smooth_out_labels(y_true_positive, 0.3)
y_true_positive = np.reshape(y_true_positive, (-1, 1))
y_true_negative = np.zeros_like(fake_output).flatten()
# Noise is added to the true values to prevent overfitting
y_true_negative = flip_labels(y_true_negative, 0.05)
# Labels are smoothed out to prevent overfitting
y_true_negative = smooth_out_labels(y_true_negative, 0.3)
y_true_negative = np.reshape(y_true_negative, (-1, 1))
real_loss = binary_crossentropy(y_true=y_true_positive, y_pred=real_output)
fake_loss = binary_crossentropy(y_true=y_true_negative, y_pred=fake_output)
return real_loss, fake_loss
# If the generator does a good job the discriminator will classify the image as real (i.e. as 1). Therefore for generator loss y_true is 1.
def calc_generator_loss(fake_output):
return binary_crossentropy(y_true=tf.ones_like(fake_output), y_pred=fake_output)
# This function makes the true values of the discriminator (i.e. 0 or 1) dynamic (Randomly between 0 and smoothness_range or smoothness_range and 1). This is really useful to ensure that the discriminator is not overconfident with its predictions.
def smooth_out_labels(labels, smoothness_range):
for i, label in enumerate(labels):
if label == 0:
labels[i] = np.random.random() * smoothness_range
elif label == 1:
labels[i] = 1 - np.random.random() * smoothness_range
return labels
# With probability flip_probability noise is added to the true values. This is really useful to ensure that the discriminator is not overconfident with its predictions.
def flip_labels(labels, flip_probability):
for i, label in enumerate(labels):
if flip_probability > np.random.random():
labels[i] = 0 if label == 1 else 1
return labels
Die Optimierer werden verwendet, um die Gewichtanpassungen abängig von den Ergebnissen der Fehlerfunktionen jede Iteration umzusetzen.
Als Optimierer wird in beiden Fällen der Adam Optimierer verwendet. Empfohlen wird hierbei für beide Netze eine Lernrate von 0.0002 und Momentum von 0.5. In unserem Fall zeigt sich jedoch, dass eine Verringerung der Lernrate des Diskriminators einen positiven Effekt auf das Training hat, da so der Diskriminator weiter verlangsamt wird und sich der Generator besser anpassen kann.
# Adam optimizer with a momentum of 0.5 and a learning rate of 0.0002 is recommended for the training of DCGANs.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5)
Im Folgenden werden alle zuvor definierten Funktionen nacheinander aufgerufen. Dies wird in einer Schleife so lange gemacht bis das Training zu Ende ist. Das Ende des Trainings wird dabei durch den Hyperparameter end_epoch (sieht weiter unten) festgelegt. Ein Epoch besteht dabei aus i Iterationen. Die Anzahl der Iterationen innerhalb eines Epochs wird wiederum durch die definierte batch size b und die Anzahl aller Trainingsbilder n festgelegt (i=n/b). Die Anzahl der Trainingsbilder ist 4881 und die batch size wurde oben mit einem Wert von 128 defininert, wodurch die Ierationen pro Epoch 39 sind.
Der Ablauf einer Iteration ist nun:
def train_iteration(images_batch):
noise = tf.random.normal(shape=[BATCH_SIZE, NOISE_DIMENSIONS])
with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
generated_images_batch = generator(noise, training=True)
real_output = discriminator(images_batch)
fake_output = discriminator(generated_images_batch)
generator_loss = calc_generator_loss(fake_output)
discriminator_real_loss, discriminator_fake_loss = calc_discriminator_loss(real_output, fake_output)
discriminator_total_loss = discriminator_real_loss + discriminator_fake_loss
discriminator_gradients = discriminator_tape.gradient(target=discriminator_total_loss, sources=discriminator.trainable_variables)
discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
generator_gradients = generator_tape.gradient(target=generator_loss, sources=generator.trainable_variables)
generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
return real_output, fake_output, discriminator_real_loss, discriminator_fake_loss, discriminator_total_loss, generator_loss
def train(image_dataset, epochs, generated_images_dir_url: str, generated_figures_dir_url: str):
training_start = time.time()
print('\nTRAINING STARTED')
# Lists with all the metrics.
generator_losses = []
discriminator_real_losses = []
discriminator_fake_losses = []
discriminator_total_losses = []
discriminator_real_accuracies = []
discriminator_fake_accuracies = []
discriminator_total_accuracies = []
for epoch in epochs:
epoch_start = time.time()
for i, image_batch in enumerate(image_dataset):
real_output, fake_output, discriminator_real_loss, discriminator_fake_loss, discriminator_total_loss, generator_loss = train_iteration(image_batch)
# At each epochs' last iteration the current metrics should be saved for future analysis
if i == len(image_dataset) - 1:
generator_losses.append(generator_loss)
discriminator_real_losses.append(discriminator_real_loss)
discriminator_fake_losses.append(discriminator_fake_loss)
discriminator_total_losses.append(discriminator_total_loss)
y_true_positive = np.ones_like(real_output)
y_true_negative = np.zeros_like(fake_output)
binary_accuracy = BinaryAccuracy()
binary_accuracy.update_state(y_true_positive, real_output)
discriminator_real_accuracy = binary_accuracy.result().numpy()
binary_accuracy.reset_states()
binary_accuracy.update_state(y_true_negative, fake_output)
discriminator_fake_accuracy = binary_accuracy.result().numpy()
binary_accuracy.reset_states()
binary_accuracy.update_state(np.vstack((y_true_positive, y_true_negative)), np.vstack((real_output, fake_output)))
discriminator_total_accuracy = binary_accuracy.result().numpy()
binary_accuracy.reset_states()
discriminator_real_accuracies.append(discriminator_real_accuracy)
discriminator_fake_accuracies.append(discriminator_fake_accuracy)
discriminator_total_accuracies.append(discriminator_total_accuracy)
ImageController.generate_and_save_images(generator, seed, os.path.join(generated_images_dir_url, f'images-at-epoch-{epoch}'))
SaveController.save_checkpoint_weights(epoch, generator, discriminator, generator_checkpoints_dir_url, discriminator_checkpoints_dir_url, generator_losses[-1], discriminator_total_losses[-1])
epoch_end = time.time()
print(f'\nTime for epoch {epoch}: {epoch_end - epoch_start} sec')
print(f'Generator loss: {generator_losses[-1]}; Discriminator real loss: {discriminator_real_losses[-1]}; Discriminator fake loss: {discriminator_fake_losses[-1]}; Discriminator total loss: {discriminator_total_losses[-1]};')
print(f'Discriminator real accuracy: {discriminator_real_accuracies[-1]}; Discriminator fake accuracy: {discriminator_fake_accuracies[-1]}; Discriminator total accuracy: {discriminator_total_accuracies[-1]}')
training_end = time.time()
print(f'\nTotal time: {(training_end - training_start)/60} min\n')
print('TRAINING COMPLETED')
Evaluator.plot_and_save_losses(generator_losses, discriminator_real_losses, discriminator_fake_losses, discriminator_total_losses, epochs, generated_figures_dir_url)
Evaluator.plot_and_save_discriminator_accuracies(discriminator_real_accuracies, discriminator_fake_accuracies, discriminator_total_accuracies, epochs, generated_figures_dir_url)
Nachfolgende Konstanten werden verwendet, um URLs zu definieren und zu steuern, wie das Training ablaufen soll.
end_epoch = 500
# URLs
generator_checkpoints_dir_url = os.path.join('assets', 'checkpoints', 'generator')
discriminator_checkpoints_dir_url = os.path.join('assets', 'checkpoints', 'discriminator')
# Directory which stores all generated media in its subdirectories (i.e. the generated Pokémon from the seed, the gif which shows the progress throughout all epochs and the evaluation figures constructed at the end of the training)
generated_media_dir_url = os.path.join('assets', 'generated-media')
# Output URL for the generator once training is over and should_save_generator is set to True
trained_generator_model_dir_url = '../public'
# Controls whether previously saved checkpoint files should be used or not
should_use_checkpoint_files = False
# Controls whether the generator model should be saved at the end of the training at the specified path for future predictions
should_save_generator = True
Aufrufen der Trainingsfunktion in Abhängigkeit zu den definierten Konstanten.
# Depending on the values of the hyperparameters and possibly training done before different initial values are set.
do_checkpoint_files_exist = SaveController.do_checkpoint_files_exist(generator_checkpoints_dir_url, discriminator_checkpoints_dir_url)
if do_checkpoint_files_exist:
print('Found previously saved checkpoint files')
if should_use_checkpoint_files:
print(f'Will be using previous checkpoint files because should_use_checkpoint_files is set to: {should_use_checkpoint_files}')
last_created_generated_media_subdir_url = SaveController.get_last_created_file_or_subdirectory_url(generated_media_dir_url)
generated_images_dir_url = os.path.join(last_created_generated_media_subdir_url, 'images')
generated_gifs_dir_url = os.path.join(last_created_generated_media_subdir_url, 'gifs')
generated_figures_dir_url = os.path.join(last_created_generated_media_subdir_url, 'figures')
last_completed_epoch = SaveController.load_latest_checkpoint_files(generator, discriminator, generator_checkpoints_dir_url, discriminator_checkpoints_dir_url)
start_epoch = last_completed_epoch + 1
else:
print(f'Will not be using previous checkpoint files because should_use_checkpoint_files is set to: {should_use_checkpoint_files}')
SaveController.delete_checkpoint_files(generator_checkpoints_dir_url, discriminator_checkpoints_dir_url)
print('Previous checkpoint files have been deleted')
generated_images_dir_url, generated_gifs_dir_url, generated_figures_dir_url = SaveController.get_new_generated_media_subdir_urls(generated_media_dir_url)
SaveController.create_generated_media_dirs(generated_images_dir_url, generated_gifs_dir_url, generated_figures_dir_url)
start_epoch = 1
else:
print('Did not find any previously saved checkpoint files')
generated_images_dir_url, generated_gifs_dir_url, generated_figures_dir_url = SaveController.get_new_generated_media_subdir_urls(generated_media_dir_url)
SaveController.create_generated_media_dirs(generated_images_dir_url, generated_gifs_dir_url, generated_figures_dir_url)
SaveController.create_checkpoint_dirs(generator_checkpoints_dir_url, discriminator_checkpoints_dir_url)
start_epoch = 1
print(f'Starting training from epoch {start_epoch}. Remaining epochs: {end_epoch - start_epoch + 1}')
train(train_dataset, range(start_epoch, end_epoch + 1), generated_images_dir_url, generated_figures_dir_url)
# A GIF which shows the whole training progress is generated, saved and plotted after the completion of the training
generated_gif_file_url = os.path.join(generated_gifs_dir_url, 'pokegan.gif')
ImageController.save_gif(generated_images_dir_url, generated_gif_file_url)
ImageController.plot_gif(generated_gif_file_url)